import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
from torchvision import datasets, transforms
from matplotlib import pyplot as plt

'''Prepare and Load Data Set'''
# Earlier, lighter augmentation pipeline kept for reference:
# pipline_train = transforms.Compose([
#     transforms.RandomCrop(32, padding=4),
#     transforms.RandomHorizontalFlip(), # random rotate figures
#     # transforms.Resize((32, 32)),       # modify the figure size to 32x32
#     transforms.ToTensor(),             # turn the figure to tensor type
#     transforms.Normalize((0.4914, 0.4822, 0.4465 ), (0.2023, 0.1994, 0.2010)) # normalize figures
# ])
# Training-time augmentation: pad-and-crop, random flip/rotation, color jitter,
# then tensor conversion and per-channel normalization (standard CIFAR-10 stats).
pipline_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(), # random horizontal flip
    transforms.RandomRotation(15),     # random rotation of up to +/-15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), # randomly change brightness, contrast, saturation and hue
    transforms.ToTensor(),             # convert PIL image to a float tensor in [0, 1]
    transforms.Normalize((0.4914, 0.4822, 0.4465 ), (0.2023, 0.1994, 0.2010)) # normalize with CIFAR-10 channel mean/std
])

# Test-time pipeline: no augmentation, only tensor conversion + normalization
# (must use the same statistics as the training pipeline).
pipline_test = transforms.Compose([
    # transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465 ), (0.2023, 0.1994, 0.2010))
])

# Downloads CIFAR-10 into ../data on first run.
train_set = torchvision.datasets.CIFAR10(root="../data", train=True, download=True, transform=pipline_train)
test_set  = torchvision.datasets.CIFAR10(root="../data", train=False, download=True, transform=pipline_test)

# num_workers left at the default (0); the commented value is a tuning knob.
trainloader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True # num_workers=2
		)
testloader  = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False # num_workers=2
		)

'''Construct ResNet Structure'''
class ResidualBlock(nn.Module):
    """Basic two-conv residual block (ResNet-18/34 style) with GELU activations.

    Main path: conv3x3 -> BN -> GELU -> conv3x3 -> BN.  A 1x1 projection
    (conv + BN) is used on the shortcut whenever the stride or the channel
    count changes; otherwise the input passes through unchanged.
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        # Build the main path as a list first, then wrap in a Sequential.
        main_path = [
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.GELU(),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        ]
        self.left = nn.Sequential(*main_path)
        needs_projection = stride != 1 or inchannel != outchannel
        if needs_projection:
            # Project the residual so its shape matches the main path output.
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )
        else:
            self.shortcut = nn.Sequential()  # identity shortcut
        self.gelu = nn.GELU()

    def forward(self, x):
        """Return GELU(main_path(x) + shortcut(x))."""
        return self.gelu(self.left(x) + self.shortcut(x))

class ResNet(nn.Module):
    """Compact ResNet for 32x32 inputs (CIFAR-10 sized).

    Stem conv -> three residual stages (64, 128, 256 channels; the last two
    halve the spatial resolution) -> a 4x4 stride-4 conv that reduces the
    8x8 map to 2x2 -> dropout -> a linear classifier over the flattened
    256*2*2 = 1024 features.
    """

    def __init__(self, ResidualBlock, num_classes=10):
        super(ResNet, self).__init__()
        self.inchannel = 64
        # Stem: 3 -> 64 channels, resolution preserved.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.GELU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        # Aggressive downsampling head: 8x8 -> 2x2 in one strided conv.
        self.conv2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=4, stride=4, bias=False),
            nn.BatchNorm2d(256),
            nn.GELU(),
        )
        self.maxpool = nn.MaxPool2d(4, 4)  # not used in forward(); kept so the module set is unchanged
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one may downsample."""
        blocks = []
        for s in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.inchannel, channels, s))
            self.inchannel = channels  # subsequent blocks see the new width
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        out = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.conv2):
            out = stage(out)
        out = self.dropout(out)
        return self.fc(out.flatten(1))

def ResNet18():
    """Build the ResNet defined above (3 residual stages of 2 blocks each, 10 classes)."""
    return ResNet(ResidualBlock)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# CIFAR-10 class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
model = ResNet18().to(device)

# Adam with a relatively high initial LR; StepLR halves it every 7 epochs.
optimizer = optim.Adam(model.parameters(), lr=0.005)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.5)

def train_runner(model, device, trainloader, optimizer, epoch):
    """Train `model` for one epoch and checkpoint its weights.

    Args:
        model: network to train (updated in place).
        device: torch.device to run on (renamed from the misspelled `devive`;
            the call site in this file passes it positionally).
        trainloader: DataLoader yielding (inputs, labels) batches.
        optimizer: optimizer stepping `model`'s parameters.
        epoch: 1-based epoch index, used for logging and the checkpoint name.

    Returns:
        (last_batch_loss, epoch_accuracy) tuple.
    """
    model.train()
    total = 0
    correct = 0.0

    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)

        loss = F.cross_entropy(outputs, labels)
        predict = outputs.argmax(dim=-1)
        total += labels.size(0)
        correct += (predict == labels).sum().item()

        loss.backward()
        optimizer.step()
        # Progress log; with the default batch size this fires only at i == 0.
        # NOTE: the old in-loop appends to the global Loss/Accuracy lists were
        # removed — they duplicated the caller's own bookkeeping of the values
        # returned below.
        if i % 1000 == 0:
            print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))

    # Checkpoint after every epoch; create the directory first so torch.save
    # does not fail on a fresh checkout.
    ckpt_dir = '../model_params/float'
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(model.state_dict(), ckpt_dir + '/ResNet18_param_' + str(epoch) + '.pth')
    return loss.item(), correct/total

def test_runner(model, device, testloader):
    model.eval()
    correct = 0.0
    test_loss = 0.0
    total = 0

    with torch.no_grad():
        for data, label in testloader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, label).item()
            predict = output.argmax(dim=1)
            total += label.size(0)
            correct += (predict == label).sum().item()
        print("test_average_loss: {:.6f}, accuracy: {:.6f}%".format(test_loss/total, 100*(correct/total)))

epoch = 20     # total number of training epochs
Loss = []      # training-loss history (appended as training progresses)
Accuracy = []  # training-accuracy history

# NOTE(review): `epoch` is reused as both the epoch count and the loop
# variable; the loop runs epochs 1..20 inclusive.
for epoch in range(1, epoch+1):
    print("Start_time", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
    # NOTE(review): dumping every parameter tensor and its gradient each epoch
    # produces enormous console output — consider removing or gating this.
    for name, param in model.named_parameters():
        print(f"Name: {name}, Parameter: {param.data}, Gradient: {param.grad}")
    loss, acc = train_runner(model, device, trainloader, optimizer, epoch)
    Loss.append(loss)
    Accuracy.append(acc)
    test_runner(model, device, testloader)
    scheduler.step()  # advance the StepLR learning-rate schedule
    print("end_time: ", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), '\n')

print('Finished Training')
# Save each training curve to its own image.  The original saved
# 'loss_plot.png' from a half-filled 2x1 subplot grid (accuracy not yet
# drawn) and then saved BOTH subplots into 'accuracy_plot.png'; using a
# fresh figure per curve gives each file exactly one complete plot.
plt.figure()
plt.plot(Loss)
plt.title('Loss')
plt.savefig('loss_plot.png')  # save the Loss curve
plt.close()

plt.figure()
plt.plot(Accuracy)
plt.title('Accuracy')
plt.savefig('accuracy_plot.png')  # save the Accuracy curve
plt.close()
